package com.utils.dt;


import com.las.utils.JsonUtils;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;

/**
 * 策略树算法
 *
 * @author dullwolf
 */
/**
 * ID3 decision-tree builder.
 *
 * <p>Each sample ({@code Data}) carries a feature map (feature name -&gt; feature value)
 * and a final classification result. {@link #createTree(List)} recursively picks the
 * feature with the highest information gain (Shannon entropy difference), splits the
 * samples on its values, and recurses until every branch is pure or out of features.
 *
 * <p>NOTE(review): assumes {@code Data.clone()} copies the feature map deeply enough
 * that removing a key from a clone does not affect the original sample — confirm
 * against the {@code Data} implementation.
 *
 * @author dullwolf
 */
public class DecisionTree {

    /** Scale used for all entropy/probability divisions. */
    private static final int DIVISION_SCALE = 5;

    /**
     * Recursively builds a decision tree for the given sample set.
     *
     * @param dataList non-empty sample set; every sample must expose the same feature keys
     * @return the root node of the (sub)tree classifying {@code dataList}
     * @throws IllegalArgumentException if {@code dataList} is null or empty
     */
    public TreeNode createTree(List<Data> dataList) {
        if (dataList == null || dataList.isEmpty()) {
            throw new IllegalArgumentException("dataList must not be null or empty");
        }

        TreeNode<String, String, String> nowTreeNode = new TreeNode<String, String, String>();

        // Collect every distinct classification result present in this sample set.
        Set<String> resultSet = new HashSet<String>();
        for (Data data : dataList) {
            resultSet.add(data.getResult());
        }

        // Only one class left: the branch is pure, emit a leaf node.
        if (resultSet.size() == 1) {
            nowTreeNode.setResultNode(resultSet.iterator().next());
            return nowTreeNode;
        }

        // No features left to split on: fall back to the majority class.
        if (dataList.get(0).getFeature().isEmpty()) {
            nowTreeNode.setResultNode(majorityResult(dataList));
            return nowTreeNode;
        }

        // Pick the feature with the highest information gain.
        String bestLabel = chooseBestFeatureToSplit(dataList);

        // All distinct values the best feature takes in this sample set.
        Set<String> bestLabelInfoSet = new HashSet<String>();
        for (Data data : dataList) {
            bestLabelInfoSet.add(data.getFeature().get(bestLabel));
        }

        // Build one child subtree per feature value.
        Map<String, TreeNode> featureDecisionMap = new HashMap<String, TreeNode>();
        for (String labelInfo : bestLabelInfoSet) {
            List<Data> branchDataList = splitDataList(dataList, bestLabel, labelInfo);
            featureDecisionMap.put(labelInfo, createTree(branchDataList));
        }

        nowTreeNode.setDecisionNode(bestLabel, featureDecisionMap);

        return nowTreeNode;
    }

    /**
     * Returns the classification result occurring most often in the sample set.
     * Ties keep the first maximum encountered in map-iteration order (matches the
     * original strict-"&gt;" loop).
     *
     * @param dataList non-empty sample set
     */
    private String majorityResult(List<Data> dataList) {
        Map<String, Integer> countMap = new HashMap<String, Integer>();
        for (Data data : dataList) {
            countMap.merge(data.getResult(), 1, Integer::sum);
        }

        String bestResult = "";
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : countMap.entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                bestResult = entry.getKey();
            }
        }
        return bestResult;
    }

    /**
     * Finds the feature whose split yields the highest information gain
     * (base entropy minus the weighted entropy of the partitions).
     *
     * @param dataList non-empty sample set; feature keys are read from the first sample
     * @return the name of the best feature to split on; when several features tie,
     *         the last one in iteration order wins (original "&gt;=" semantics)
     */
    public String chooseBestFeatureToSplit(List<Data> dataList) {
        // Feature names currently available in this sample set.
        Set<String> featureSet = dataList.get(0).getFeature().keySet();

        // Entropy of the unsplit sample set.
        BigDecimal baseEntropy = calcShannonEnt(dataList);

        BigDecimal bestInfoGain = BigDecimal.ZERO;
        String bestFeature = "";

        BigDecimal total = BigDecimal.valueOf(dataList.size());

        for (String feature : featureSet) {
            // Weighted entropy after splitting on this feature.
            BigDecimal featureEntropy = BigDecimal.ZERO;

            // Distinct values of this feature.
            Set<String> valueSet = new HashSet<String>();
            for (Data data : dataList) {
                valueSet.add(data.getFeature().get(feature));
            }

            for (String value : valueSet) {
                List<Data> subset = splitDataList(dataList, feature, value);

                // Weight = |subset| / |dataList|, rounded to 5 decimal places.
                BigDecimal prob = BigDecimal.valueOf(subset.size())
                        .divide(total, DIVISION_SCALE, RoundingMode.HALF_DOWN);

                featureEntropy = featureEntropy.add(prob.multiply(calcShannonEnt(subset)));
            }

            BigDecimal infoGain = baseEntropy.subtract(featureEntropy);

            // ">=" so a later feature with equal gain replaces an earlier one
            // (preserves the original selection behavior).
            if (infoGain.compareTo(bestInfoGain) >= 0) {
                bestInfoGain = infoGain;
                bestFeature = feature;
            }
        }

        return bestFeature;
    }

    /**
     * Computes the Shannon entropy of the sample set's classification results:
     * {@code -sum(p_i * log2(p_i))} with each {@code p_i} rounded to 5 decimals.
     *
     * @param dataList non-empty sample set
     * @return the entropy as a {@code BigDecimal} (non-negative for valid input)
     */
    public BigDecimal calcShannonEnt(List<Data> dataList) {
        BigDecimal sumEntries = BigDecimal.valueOf(dataList.size());
        BigDecimal shannonEnt = BigDecimal.ZERO;

        // Count samples per classification result.
        Map<String, Integer> resultCountMap = new HashMap<String, Integer>();
        for (Data data : dataList) {
            resultCountMap.merge(data.getResult(), 1, Integer::sum);
        }

        for (Integer count : resultCountMap.values()) {
            BigDecimal prob = BigDecimal.valueOf(count)
                    .divide(sumEntries, DIVISION_SCALE, RoundingMode.HALF_DOWN);

            // log2(p) via the change-of-base formula; counts are >= 1 so p > 0.
            double log2Prob = Math.log(prob.doubleValue()) / Math.log(2);
            shannonEnt = shannonEnt.subtract(prob.multiply(BigDecimal.valueOf(log2Prob)));
        }

        return shannonEnt;
    }

    /**
     * Partitions the sample set by one feature value: keeps every sample whose
     * {@code future} feature equals {@code info}, cloning it and removing the
     * consumed feature so child nodes never re-split on it.
     *
     * @param dataList sample set to partition
     * @param future   feature name to filter on
     * @param info     feature value to match (null-safe)
     * @return clones of the matching samples with {@code future} removed
     */
    public List<Data> splitDataList(List<Data> dataList, String future, String info) {
        List<Data> resultDataList = new ArrayList<Data>();
        for (Data data : dataList) {
            // Objects.equals: null-safe, so a sample missing this feature is
            // simply skipped instead of throwing NullPointerException.
            if (Objects.equals(data.getFeature().get(future), info)) {
                Data newData = (Data) data.clone();
                newData.getFeature().remove(future);
                resultDataList.add(newData);
            }
        }

        return resultDataList;
    }

    /**
     * A node of the decision tree: either a leaf carrying a final classification,
     * or an internal node carrying a feature label and one child per feature value.
     *
     * @param <L> type of the feature label describing this node's split
     * @param <F> type of the feature values keying the child map
     * @param <R> type of the final classification result
     */
    public class TreeNode<L, F, R> {
        /** Feature label this node splits on (internal nodes only). */
        private L label;

        /** Child nodes keyed by feature value (internal nodes only). */
        private Map<F, TreeNode> featureDecisionMap;

        /** True when this node is a leaf carrying a final classification. */
        private boolean isFinal;

        /** Final classification result (leaf nodes only). */
        private R resultClassify;

        /**
         * Marks this node as a leaf.
         *
         * @param resultClassify the final classification result
         */
        public void setResultNode(R resultClassify) {
            this.isFinal = true;
            this.resultClassify = resultClassify;
        }

        /**
         * Marks this node as an internal split node.
         *
         * @param label              the feature this node splits on
         * @param featureDecisionMap child node per feature value
         */
        public void setDecisionNode(L label, Map<F, TreeNode> featureDecisionMap) {
            this.isFinal = false;
            this.label = label;
            this.featureDecisionMap = featureDecisionMap;
        }

        /**
         * Renders the subtree rooted at this node as nested map {@code toString} text.
         *
         * @return e.g. {@code {result=yes}} for a leaf, {@code {feature={v1=..., v2=...}}} otherwise
         */
        public String showTree() {
            HashMap<String, String> treeMap = new HashMap<String, String>();
            if (isFinal) {
                treeMap.put("result", resultClassify.toString());
            } else {
                HashMap<F, String> childMap = new HashMap<F, String>();
                for (Map.Entry<F, TreeNode> entry : featureDecisionMap.entrySet()) {
                    childMap.put(entry.getKey(), entry.getValue().showTree());
                }
                treeMap.put(label.toString(), childMap.toString());
            }

            return treeMap.toString();
        }

        /**
         * Renders the subtree rooted at this node as a JSON string via {@code JsonUtils}.
         */
        public String showTreeJsonStr() {
            HashMap<String, Map<String, Object>> treeMap = new HashMap<>();
            if (isFinal) {
                treeMap.put("result", JsonUtils.getMapByObject(resultClassify));
            } else {
                HashMap<F, Object> childMap = new HashMap<>();
                for (Map.Entry<F, TreeNode> entry : featureDecisionMap.entrySet()) {
                    childMap.put(entry.getKey(), entry.getValue());
                }
                treeMap.put(label.toString(), JsonUtils.getMapByObject(childMap));
            }

            return JsonUtils.getJsonString(treeMap);
        }

        public L getLabel() {
            return label;
        }

        public Map<F, TreeNode> getFeatureDecisionMap() {
            return featureDecisionMap;
        }

        public R getResultClassify() {
            return resultClassify;
        }

        public boolean getFinal() {
            return isFinal;
        }
    }
}